{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "139cf9cb-1268-47e6-b495-489f99b2cc29",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, \"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdf0f8a8-b95c-4335-a5e1-a2d6f4a31fed",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "\n",
    "from app.vjepa_droid.transforms import make_transforms\n",
    "from utils.mpc_utils import (\n",
    "    compute_new_pose,\n",
    "    poses_to_diff\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1baab712-266c-4822-b229-673ef728abb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side length (in pixels) of the square crop handed to the encoder\n",
    "crop_size = 256\n",
    "\n",
    "# Load the V-JEPA 2-AC encoder / predictor pair through torch.hub\n",
    "encoder, predictor = torch.hub.load(\n",
    "    \"facebookresearch/vjepa2\",\n",
    "    \"vjepa2_ac_vit_giant\",\n",
    ")\n",
    "\n",
    "# One token per patch: (patches per side) squared\n",
    "patches_per_side = crop_size // encoder.patch_size\n",
    "tokens_per_frame = int(patches_per_side ** 2)\n",
    "\n",
    "# Build the preprocessing transform with flip, auto-augment, motion shift and\n",
    "# random erasing switched off, and with resize scale / aspect ratio pinned to 1\n",
    "transform = make_transforms(\n",
    "    crop_size=crop_size,\n",
    "    random_horizontal_flip=False,\n",
    "    auto_augment=False,\n",
    "    motion_shift=False,\n",
    "    reprob=0.0,\n",
    "    random_resize_scale=(1.0, 1.0),\n",
    "    random_resize_aspect_ratio=(1.0, 1.0),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3b910c0-90cc-4d95-8ed9-dabd5304741b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load robot trajectory\n",
    "\n",
    "play_in_reverse = False  # Use this FLAG to try loading the trajectory backwards, and see how the energy landscape changes\n",
    "\n",
    "# np.load returns an NpzFile for .npz archives; the context manager closes the\n",
    "# underlying file once both arrays have been read into memory.\n",
    "# NOTE: `trajectory` is closed after this block and must not be indexed again.\n",
    "with np.load(\"franka_example_traj.npz\") as trajectory:\n",
    "    np_clips = trajectory[\"observations\"]\n",
    "    np_states = trajectory[\"states\"]\n",
    "\n",
    "if play_in_reverse:\n",
    "    # Reverse along axis 1 (presumably time -- TODO confirm against the dataset);\n",
    "    # .copy() materializes the reversed view as a regular contiguous array.\n",
    "    np_clips = np_clips[:, ::-1].copy()\n",
    "    np_states = np_states[:, ::-1].copy()\n",
    "\n",
    "# Action taking the first pose to the second one, with two leading singleton axes added\n",
    "np_actions = np.expand_dims(poses_to_diff(np_states[0, 0], np_states[0, 1]), axis=(0, 1))\n",
    "\n",
    "# Convert trajectory to torch tensors\n",
    "clips = transform(np_clips[0]).unsqueeze(0)\n",
    "states = torch.tensor(np_states)\n",
    "actions = torch.tensor(np_actions)\n",
    "print(f\"clips: {clips.shape}; states: {states.shape}; actions: {actions.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fbc96f4-e310-420d-9230-09edddbe91ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize loaded video frames from traj\n",
    "\n",
    "# np_clips[0] is laid out as (T, H, W, C); read the frame geometry from the data\n",
    "# instead of hard-coding 256 x 256 x 3 so other resolutions also render.\n",
    "T, frame_h, frame_w, n_channels = np_clips[0].shape\n",
    "plt.figure(figsize=(20, 3))\n",
    "# (T, H, W, C) -> (H, T, W, C) -> (H, T * W, C): all frames side by side in one strip\n",
    "_ = plt.imshow(np.transpose(np_clips[0], (1, 0, 2, 3)).reshape(frame_h, frame_w * T, n_channels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01dc4f05-14b9-4e06-bdc0-cf61bd88e68b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_target(c, normalize_reps=True):\n",
    "    \"\"\"Encode every frame of a clip on its own and concatenate the per-frame tokens.\n",
    "\n",
    "    Args:\n",
    "        c: 5-D clip tensor indexed as [B, C, T, H, W].\n",
    "        normalize_reps: when True, layer-normalize the output over its last dimension.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape [B, T * N, D], where N is the number of tokens the encoder\n",
    "        emits per frame and D is the encoder feature size.\n",
    "    \"\"\"\n",
    "    batch, n_frames = c.size(0), c.size(2)\n",
    "    # [B, C, T, H, W] -> [B*T, C, H, W]: each frame becomes its own batch item\n",
    "    frames = c.permute(0, 2, 1, 3, 4).flatten(0, 1)\n",
    "    # Duplicate each frame along a new axis -> [B*T, C, 2, H, W]\n",
    "    # (presumably because the encoder consumes 2-frame inputs -- TODO confirm)\n",
    "    frame_pairs = frames.unsqueeze(2).repeat(1, 1, 2, 1, 1)\n",
    "    h = encoder(frame_pairs)\n",
    "    feat_dim = h.size(-1)\n",
    "    # Regroup the tokens by clip, then merge the frame and token axes\n",
    "    h = h.view(batch, n_frames, -1, feat_dim).flatten(1, 2)\n",
    "    if normalize_reps:\n",
    "        h = F.layer_norm(h, (feat_dim,))\n",
    "    return h\n",
    "\n",
    "\n",
    "def forward_actions(z, nsamples, grid_size=0.075, normalize_reps=True, action_repeat=1):\n",
    "    \"\"\"Roll the predictor forward for every action on a cubic grid of translations.\n",
    "\n",
    "    Args:\n",
    "        z: token representations; the first `tokens_per_frame` tokens are the context frame.\n",
    "        nsamples: number of grid points per axis (nsamples**3 actions in total).\n",
    "        grid_size: half-width of the grid along each of the first three action components.\n",
    "        normalize_reps: when True, layer-normalize every predicted frame.\n",
    "        action_repeat: number of predictor steps, each applying the same action.\n",
    "\n",
    "    Returns:\n",
    "        (z_hat, s_hat, a_hat): context tokens followed by the predicted frames, the\n",
    "        context pose followed by the predicted poses, and the action sequence, which\n",
    "        holds action_repeat + 1 entries (one more than was rolled out).\n",
    "\n",
    "    Reads the notebook globals `states`, `predictor` and `tokens_per_frame`.\n",
    "    \"\"\"\n",
    "    # Sample grid of actions: cartesian product over the first three components,\n",
    "    # remaining four components zero. \"ij\" indexing keeps the third component\n",
    "    # varying fastest.\n",
    "    ticks = np.linspace(-grid_size, grid_size, nsamples)\n",
    "    xyz = np.stack(np.meshgrid(ticks, ticks, ticks, indexing=\"ij\"), axis=-1).reshape(-1, 3)\n",
    "    action_samples = torch.zeros(len(xyz), 1, 7, device=z.device, dtype=z.dtype)\n",
    "    action_samples[:, 0, :3] = torch.tensor(xyz, device=z.device, dtype=z.dtype)\n",
    "    print(f\"Sampled grid of actions; num actions = {len(action_samples)}\")\n",
    "\n",
    "    # Context frame rep and context pose, replicated once per sampled action\n",
    "    num_samples = action_samples.size(0)\n",
    "    z_hat = z[:, :tokens_per_frame].repeat(num_samples, 1, 1)  # [S, N, D]\n",
    "    s_hat = states[:, :1].repeat((num_samples, 1, 1))  # [S, 1, 7]\n",
    "    a_hat = action_samples  # [S, 1, 7]\n",
    "\n",
    "    for _ in range(action_repeat):\n",
    "        # Keep only the tokens of the newest predicted frame\n",
    "        next_z = predictor(z_hat, a_hat, s_hat)[:, -tokens_per_frame:]\n",
    "        if normalize_reps:\n",
    "            next_z = F.layer_norm(next_z, (next_z.size(-1),))\n",
    "        # Advance the pose using the latest pose and the latest action\n",
    "        next_s = compute_new_pose(s_hat[:, -1:], a_hat[:, -1:])\n",
    "\n",
    "        z_hat = torch.cat([z_hat, next_z], dim=1)\n",
    "        s_hat = torch.cat([s_hat, next_s], dim=1)\n",
    "        a_hat = torch.cat([a_hat, action_samples], dim=1)\n",
    "\n",
    "    return z_hat, s_hat, a_hat\n",
    "\n",
    "def loss_fn(z, h):\n",
    "    \"\"\"Mean absolute error between the last predicted frame and the last target frame.\n",
    "\n",
    "    Only the final `tokens_per_frame` tokens of each input are compared. Returns a\n",
    "    Python list with one value per batch entry.\n",
    "    \"\"\"\n",
    "    pred_last = z[:, -tokens_per_frame:]\n",
    "    target_last = h[:, -tokens_per_frame:]\n",
    "    # Element-wise L1 distance [B, N, D], averaged over tokens and features\n",
    "    per_sample = (pred_last - target_last).abs().mean(dim=(1, 2))\n",
    "    return per_sample.tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01c2e7c4-8cd6-454c-89e9-f060bf4978cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Energy of every action on the (nsamples x nsamples x nsamples) cartesian grid\n",
    "grid_size = 0.075  # half-width of the grid along each axis\n",
    "nsamples = 5  # grid points per axis\n",
    "with torch.no_grad():\n",
    "    # Target representations for the whole clip\n",
    "    h = forward_target(clips)\n",
    "    # Predicted representations, poses and actions for each grid point\n",
    "    z_hat, s_hat, a_hat = forward_actions(h, nsamples, grid_size)\n",
    "    # JEPA prediction loss, one value per sampled action\n",
    "    loss = loss_fn(z_hat, h)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79d2b9a2-0989-4432-87c3-7e0a373a1c57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the energy\n",
    "\n",
    "# Net translation applied by each sampled action sequence. The last entry of a_hat is\n",
    "# dropped because the rollout stores one more action than it actually applies.\n",
    "# Converting to NumPy explicitly keeps this cell working when the tensors are not on CPU.\n",
    "net_delta = a_hat[:, :-1, :3].sum(dim=1).detach().cpu().numpy()\n",
    "delta_x, delta_y, delta_z = net_delta[:, 0], net_delta[:, 1], net_delta[:, 2]\n",
    "energy = np.asarray(loss)\n",
    "\n",
    "# Ground-truth action is read from np_actions rather than `actions`, so this cell can be\n",
    "# re-run safely even if `actions` has been rebound elsewhere in the notebook.\n",
    "gt_x, gt_y, gt_z = np_actions[0, 0, :3]\n",
    "\n",
    "# Create the 2D histogram over (x, z); the weights of all y samples that share an\n",
    "# (x, z) bin are summed together.\n",
    "heatmap, xedges, yedges = np.histogram2d(delta_x, delta_z, weights=energy, bins=nsamples)\n",
    "\n",
    "# Set axis labels\n",
    "plt.xlabel(\"Action Delta x\")\n",
    "plt.ylabel(\"Action Delta z\")\n",
    "plt.title(\"Energy Landscape\")\n",
    "\n",
    "# Display the heatmap\n",
    "print(f\"Ground truth action (x,y,z) = ({gt_x:.2f},{gt_y:.2f},{gt_z:.2f})\")\n",
    "_ = plt.imshow(heatmap.T, origin=\"lower\", extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], cmap=\"viridis\")\n",
    "_ = plt.colorbar()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30525f3c-5479-4cfb-b2eb-6944b523e061",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the optimal action using MPC\n",
    "from utils.world_model_wrapper import WorldModel\n",
    "\n",
    "world_model = WorldModel(\n",
    "    encoder=encoder,\n",
    "    predictor=predictor,\n",
    "    tokens_per_frame=tokens_per_frame,\n",
    "    transform=transform,\n",
    "    # Doing very few CEM iterations with very few samples just to run efficiently on CPU...\n",
    "    # ... increase cem_steps and samples for more accurate optimization of energy landscape\n",
    "    mpc_args={\n",
    "        \"rollout\": 2,\n",
    "        \"samples\": 25,\n",
    "        \"topk\": 10,\n",
    "        \"cem_steps\": 2,\n",
    "        \"momentum_mean\": 0.15,\n",
    "        \"momentum_mean_gripper\": 0.15,\n",
    "        \"momentum_std\": 0.75,\n",
    "        \"momentum_std_gripper\": 0.15,\n",
    "        \"maxnorm\": 0.075,\n",
    "        \"verbose\": True\n",
    "    },\n",
    "    normalize_reps=True,\n",
    "    device=\"cpu\"\n",
    ")\n",
    "\n",
    "with torch.no_grad():\n",
    "    h = forward_target(clips)\n",
    "    # First frame tokens are the current observation, last frame tokens are the goal\n",
    "    z_n, z_goal = h[:, :tokens_per_frame], h[:, -tokens_per_frame:]\n",
    "    s_n = states[:, :1]\n",
    "    print(\"Starting planning using Cross-Entropy Method...\")\n",
    "    # Stored under its own name so `actions` (the ground-truth action tensor created when\n",
    "    # the trajectory was loaded) stays intact and earlier cells remain re-runnable.\n",
    "    planned_actions = world_model.infer_next_action(z_n, s_n, z_goal).cpu().numpy()\n",
    "\n",
    "print(f\"Actions returned by planning with CEM (x,y,z) = ({planned_actions[0, 0]:.2f},{planned_actions[0, 1]:.2f},{planned_actions[0, 2]:.2f})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfb8bfa9-0db1-4096-9bc9-6c2ae7ef5b86",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
